{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training of smallNORB model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision as tv\n",
    "\n",
    "from tqdm import tqdm\n",
    "from sklearn.manifold import MDS\n",
    "from pytorch_extras import RAdam, SingleCycleScheduler\n",
    "from deps.small_norb.smallnorb.dataset import SmallNORBDataset\n",
    "from deps.torch_train_test_loop.torch_train_test_loop import LoopComponent, TrainTestLoop\n",
    "\n",
    "from models import SmallNORBClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = 'cuda:0'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smallnorb = SmallNORBDataset(dataset_root='.data/smallnorb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SmallNORBTorchDataset(torch.utils.data.Dataset):\n",
    "\n",
    "    def __init__(self, data, categories, preprocessing):\n",
    "        self.data = data\n",
    "        self.categories = categories\n",
    "        self.preprocess = preprocessing\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        images = np.stack((self.data[i].image_lt, self.data[i].image_rt), axis=-1)  # [96, 96, 2]\n",
    "        images = self.preprocess(images).cuda(device=DEVICE)  # [2, 96, 96]\n",
    "        category = torch.tensor(self.data[i].category, dtype=torch.long).cuda(device=DEVICE)\n",
    "        return { 'images': images, 'category': category, }\n",
    "\n",
    "random_crops = tv.transforms.Compose([\n",
    "    tv.transforms.ToPILImage(),\n",
    "    tv.transforms.RandomCrop(size=96, padding=16, padding_mode='edge'),\n",
    "    tv.transforms.ToTensor()\n",
    "])\n",
    "\n",
    "# Normally we would divide train set into train/valid splits; we don't here to match other papers.\n",
    "trn_ds = SmallNORBTorchDataset(smallnorb.data['train'], smallnorb.categories, random_crops)\n",
    "val_ds = SmallNORBTorchDataset(smallnorb.data['test'], smallnorb.categories, random_crops)\n",
    "tst_ds = SmallNORBTorchDataset(smallnorb.data['test'], smallnorb.categories, tv.transforms.ToTensor())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LoopMain(LoopComponent):\n",
    "\n",
    "    def __init__(self, n_classes, device, pct_warmup=0.1, mixup=(0.2, 0.2)):\n",
    "        self.n_classes, self.device, self.pct_warmup = (n_classes, device, pct_warmup)\n",
    "        self.mixup_dist = torch.distributions.Beta(torch.tensor(mixup[0]), torch.tensor(mixup[1]))\n",
    "        self.to_onehot = torch.eye(self.n_classes, device=self.device)\n",
    "        self.saved_data = []\n",
    "\n",
    "    def on_train_begin(self, loop):\n",
    "        n_iters = len(loop.train_data) * loop.n_epochs\n",
    "        loop.optimizer = RAdam(loop.model.parameters(), lr=5e-4)\n",
    "        loop.scheduler = SingleCycleScheduler(\n",
    "            loop.optimizer, loop.n_optim_steps, frac=self.pct_warmup, min_lr=1e-5)\n",
    "\n",
    "    def on_grads_reset(self, loop):\n",
    "        loop.model.zero_grad()\n",
    "\n",
    "    def on_forward_pass(self, loop):\n",
    "        images, category = loop.batch['images'], loop.batch['category']\n",
    "        target_probs = self.to_onehot[category]\n",
    "\n",
    "        if loop.is_training:\n",
    "            r = self.mixup_dist.sample([len(images)]).to(device=images.device)\n",
    "            idx = torch.randperm(len(images))\n",
    "            images = images.lerp(images[idx], r[:, None, None, None])\n",
    "            target_probs = target_probs.lerp(target_probs[idx], r[:, None])\n",
    "\n",
    "        pred_scores, _, _ = model(images)\n",
    "        _, pred_ids = pred_scores.max(-1)\n",
    "        accuracy = (pred_ids == category).float().mean()\n",
    "\n",
    "        loop.pred_scores, loop.target_probs, loop.accuracy = (pred_scores, target_probs, accuracy)\n",
    "\n",
    "    def on_loss_compute(self, loop):\n",
    "        losses = -loop.target_probs * F.log_softmax(loop.pred_scores, dim=-1)  # CE\n",
    "        loop.loss = losses.sum(dim=-1).mean()  # sum of classes, mean of batch\n",
    "\n",
    "    def on_backward_pass(self, loop):\n",
    "        loop.loss.backward()\n",
    "\n",
    "    def on_optim_step(self, loop):\n",
    "        loop.optimizer.step()\n",
    "        loop.scheduler.step()\n",
    "\n",
    "    def on_batch_end(self, loop):\n",
    "        self.saved_data.append({\n",
    "            'n_samples': len(loop.batch['images']),\n",
    "            'epoch_desc': loop.epoch_desc,\n",
    "            'epoch_num': loop.epoch_num,\n",
    "            'epoch_frac': loop.epoch_num + loop.batch_num / loop.n_batches,\n",
    "            'batch_num' : loop.batch_num,\n",
    "            'accuracy': loop.accuracy.item(),\n",
    "            'loss': loop.loss.item(),\n",
    "            'lr': loop.optimizer.param_groups[0]['lr'],\n",
    "            'momentum': loop.optimizer.param_groups[0]['betas'][0],\n",
    "        })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LoopProgressBar(LoopComponent):\n",
    "\n",
    "    def __init__(self, item_names=['loss', 'accuracy']):\n",
    "        self.item_names = item_names\n",
    "\n",
    "    def on_epoch_begin(self, loop):\n",
    "        self.total, self.count = ({ name: 0.0 for name in self.item_names }, 0)\n",
    "        self.pbar = tqdm(total=loop.n_batches, desc=f\"{loop.epoch_desc} epoch {loop.epoch_num}\")\n",
    "\n",
    "    def on_batch_end(self, loop):\n",
    "        n = len(loop.batch['images'])\n",
    "        self.count += n\n",
    "        for name in self.item_names:\n",
    "            self.total[name] += getattr(loop, name).item() * n\n",
    "        self.pbar.update(1)\n",
    "        if (not loop.is_training):\n",
    "            self.pbar.set_postfix(self.mean)\n",
    "\n",
    "    def on_epoch_end(self, loop):\n",
    "        self.pbar.close()\n",
    "\n",
    "    @property\n",
    "    def mean(self): return {\n",
    "        f'mean_{name}': self.total[name] / self.count for name in self.item_names\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize and train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Seed RNG for replicability. Run at least a few times without seeding to measure performance.\n",
    "# torch.manual_seed(<type an int here>)\n",
    "\n",
    "# Make iterators for each split, with random shuffling in train set.\n",
    "trn_itr = torch.utils.data.DataLoader(trn_ds, batch_size=20, shuffle=True)\n",
    "val_itr = torch.utils.data.DataLoader(val_ds, batch_size=20, shuffle=False)\n",
    "tst_itr = torch.utils.data.DataLoader(tst_ds, batch_size=20, shuffle=False)\n",
    "\n",
    "# Initialize model.\n",
    "n_classes = len(trn_ds.categories)\n",
    "model = SmallNORBClassifier(n_objs=n_classes, n_parts=64, d_chns=64)\n",
    "model = model.cuda(device=DEVICE)\n",
    "print('Total number of parameters: {:,}'.format(sum(np.prod(p.shape) for p in model.parameters())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train model\n",
    "loop = TrainTestLoop(model, [LoopMain(n_classes, DEVICE), LoopProgressBar()], trn_itr, val_itr)\n",
    "loop.train(n_epochs=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loop.test(tst_itr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize capsule trajectories as we change azimuth/elevation of test images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_obj_sequence(category, instance, vary='azimuth'):\n",
    "    const_attr = 'elevation' if (vary == 'azimuth') else 'azimuth'\n",
    "    samples = [\n",
    "        s for s in smallnorb.data['test'] if\n",
    "        (s.category, s.instance, s.lighting, getattr(s, const_attr)) == (category, instance, 0, 0)\n",
    "    ]\n",
    "    ds = SmallNORBTorchDataset(samples, smallnorb.categories, tv.transforms.ToTensor())\n",
    "    with torch.no_grad():\n",
    "        a, mu, sig2 = loop.model(torch.stack([sample['images'] for sample in ds], dim=0))\n",
    "    mu = mu[:, category, :, :].cpu()\n",
    "    return ds.data, mu\n",
    "\n",
    "data, mu = get_obj_sequence(category=2, instance=0, vary='elevation')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(ncols=len(data), figsize=(len(data), 1))\n",
    "for sample, axis in zip(data, axes):\n",
    "    axis.imshow(sample.image_lt, cmap='gray', vmin=0, vmax=255)\n",
    "    axis.axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axis = plt.subplots(figsize=(3, 3))\n",
    "mds = MDS(n_components=2)\n",
    "x = mu.contiguous().view(-1, mu.shape[-1]).numpy()\n",
    "x = mds.fit_transform(x)\n",
    "x = x.reshape(mu.shape[0], 4, 2)\n",
    "blue = '#1f77b4'\n",
    "for i in range(len(x)):\n",
    "    vert = np.concatenate((x[i], x[i, :1, :]), axis=0)\n",
    "    alpha = 0.2 + (i + 1.0) / len(x) * 0.8\n",
    "    axis.plot(vert[:, 0], vert[:, 1], color=blue, alpha=alpha)\n",
    "    if (i == 0) or (i + 1 == len(x)):\n",
    "        axis.scatter(vert[0, 0], vert[0, 1], color=blue,\n",
    "                     facecolor=('white' if i == 0 else blue))\n",
    "    axis.set(xticks=[], yticks=[])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
